
[Transform] [Utils] Support precision, add torch dtype validation #414

Merged
merged 5 commits into from
Aug 11, 2025

kylesayrs
Contributor

Purpose

  • Support configuring the precision at which transforms are applied, which seems to have some minor effects on results
  • Add TorchDtype type annotation for adding torch dtypes to model definitions

Changes

  • Added precision argument to TransformSchemes
    • Transform weights are constructed using this precision
    • Transform weights are applied using this precision
      • This precision is used for both fusing operations and online transforms
  • Added TorchDtype type annotation in src/utils/type.py
    • Supports loading from "torch.xxx" or "xxx" strings, as well as from torch.dtype objects
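
As an illustration of the loading behavior described above, here is a minimal sketch of resolving a "torch.xxx" string, a bare "xxx" string, or an existing dtype object. The helper `resolve_dtype` and the `FakeDtype`/`fake_torch` stand-ins are hypothetical and exist only so the sketch runs without torch installed; this is not the TorchDtype implementation from this PR.

```python
from types import SimpleNamespace
from typing import Any


def resolve_dtype(value: Any, namespace: Any, dtype_type: type, prefix: str = "torch.") -> Any:
    """Resolve `value` to a dtype object that lives on `namespace` (hypothetical helper)."""
    # Already a dtype object: pass it through unchanged.
    if isinstance(value, dtype_type):
        return value
    if not isinstance(value, str):
        raise TypeError(f"expected a dtype or a string, got {type(value).__name__}")
    # Accept both "torch.float32" and "float32".
    name = value[len(prefix):] if value.startswith(prefix) else value
    resolved = getattr(namespace, name, None)
    # Reject names that exist on the namespace but are not dtypes (e.g. "nn").
    if not isinstance(resolved, dtype_type):
        raise ValueError(f"{value!r} does not name a dtype")
    return resolved


class FakeDtype:
    """Stand-in for torch.dtype (assumption, for illustration only)."""

    def __init__(self, name: str) -> None:
        self.name = name


# Stand-in for the torch module (assumption, for illustration only).
fake_torch = SimpleNamespace(
    float32=FakeDtype("float32"),
    bfloat16=FakeDtype("bfloat16"),
    nn=object(),
)

print(resolve_dtype("torch.float32", fake_torch, FakeDtype).name)
print(resolve_dtype("bfloat16", fake_torch, FakeDtype).name)
```

With torch installed, the same helper would be called as `resolve_dtype("torch.float32", torch, torch.dtype)`.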

Testing

  • Added tests for TorchDtype type annotation
  • Tested with different precisions and found torch.float32 to be acceptable
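
To give a rough sense of why the precision of transform weights can matter, the sketch below builds a normalized Hadamard matrix, rounds its entries to half or single precision, and measures how far the result is from orthogonal. It is an assumption-laden illustration in pure Python (rounding is emulated with `struct`, and all helper names are hypothetical); it is not the test that was run for this PR and does not use the library's transform code.

```python
import math
import struct


def round_to(value: float, fmt: str) -> float:
    """Round a Python float to IEEE half ('e') or single ('f') precision."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def hadamard(n: int) -> list:
    """Sylvester construction of an n x n Hadamard matrix scaled by 1/sqrt(n).

    Assumes n is a power of two.
    """
    rows = [[1.0]]
    while len(rows) < n:
        rows = [r + r for r in rows] + [r + [-x for x in r] for r in rows]
    scale = 1.0 / math.sqrt(n)
    return [[x * scale for x in r] for r in rows]


def orthogonality_error(matrix: list) -> float:
    """Largest absolute deviation of matrix @ matrix.T from the identity."""
    worst = 0.0
    for i, row_i in enumerate(matrix):
        for j, row_j in enumerate(matrix):
            dot = sum(a * b for a, b in zip(row_i, row_j))
            target = 1.0 if i == j else 0.0
            worst = max(worst, abs(dot - target))
    return worst


def error_at_precision(n: int, fmt: str) -> float:
    """Orthogonality error after storing the weights at the given precision."""
    weights = [[round_to(x, fmt) for x in row] for row in hadamard(n)]
    return orthogonality_error(weights)


err_half = error_at_precision(8, "e")    # weights rounded to float16
err_single = error_at_precision(8, "f")  # weights rounded to float32
print(f"float16 weights: {err_half:.3e}, float32 weights: {err_single:.3e}")
```

The size 8 is chosen because 1/sqrt(8) is not exactly representable in either format, so the rounding error is visible; the products themselves are still accumulated in Python's double precision, so this only models the precision at which weights are constructed, not the precision at which they are applied.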

Collaborator

@dsikka dsikka left a comment

just a few questions - lgtm otherwise

Signed-off-by: Kyle Sayers <[email protected]>
Contributor

@brian-dellabetta brian-dellabetta left a comment

Thanks for the updates to merged precision!

@kylesayrs kylesayrs merged commit 131673e into main Aug 11, 2025
1 check passed
@kylesayrs kylesayrs deleted the kylesayrs/transform-precision branch August 11, 2025 15:18
brian-dellabetta added a commit to vllm-project/llm-compressor that referenced this pull request Aug 13, 2025
## Purpose ##
* Enable offline spinquant-style transforms

## Prerequisites ##
* neuralmagic/compressed-tensors#370
* neuralmagic/compressed-tensors#412
* neuralmagic/compressed-tensors#414

## Changes ##
* Added `spinquant_example.py` to examples folder
* Added `SpinQuantModifier` which handles the construction of a
spinquant-style transform config

## Testing ##
* Added modifier serialization and correctness tests

## Evaluation ##
Using this branch and [the original SpinQuant
code](https://github.com/facebookresearch/SpinQuant), we see very
similar results for `meta-llama/Llama-3.2-1B-Instruct` with W4A16
quantization. Results are equivalent in hf (in-memory vs. serialized and
re-loaded), and very similar in vllm. The symmetric scales calculation
in `llm-compressor` differs slightly from the one in the original
SpinQuant paper, which uses the original GPTQ implementation. When that
calculation is swapped in, results are consistent, with hadamard
improving results on `gsm8k_llama` and `arc_challenge_llama`:

Scheme | Impl | gsm8k | gsm8k_llama | arc_challenge_llama
-- | -- | -- | -- | --
Hadamard+W4A16 | LC | 0.2403 | 0.2835 | 0.5262
W4A16 | LC | 0.1964 | 0.1933 | 0.4781
Hadamard+W4A16 | LC+SQscales | 0.1721 | 0.2183 | 0.485
W4A16 | LC+SQscales | 0.207 | 0.1706 | 0.4498
Hadamard+W4A16 | SQ | 0.1736 | 0.2282 | 0.4807
W4A16 | SQ | 0.1986 | 0.1774 | 0.4489

To run LC+SQscales, change [this line in
CT](https://github.com/neuralmagic/compressed-tensors/blob/b2df366797b00330ec765f5891dde14e4cc74c9d/src/compressed_tensors/quantization/utils/helpers.py#L111)
from

```python
scales = max_val_pos / (float(bit_range) / 2)
```
to
```python
scales = max_val_pos / (float(bit_max))
```
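
As a rough worked example of how the two formulas above differ, the sketch below assumes a signed 4-bit integer range of [-8, 7], so that `bit_range = 15` and `bit_max = 7`. The helper name and the range derivation are illustrative assumptions, not code taken from compressed-tensors or SpinQuant.

```python
def symmetric_scales(max_val_pos: float, num_bits: int = 4) -> tuple:
    """Return (lc_scale, sq_scale) for one group's maximum absolute value."""
    # Assumed signed integer range, e.g. [-8, 7] for 4 bits.
    bit_min = -(2 ** (num_bits - 1))
    bit_max = 2 ** (num_bits - 1) - 1
    bit_range = bit_max - bit_min
    # First formula: divide by half of the full range (7.5 for 4 bits).
    lc_scale = max_val_pos / (float(bit_range) / 2)
    # Second formula: divide by the largest positive level (7 for 4 bits).
    sq_scale = max_val_pos / float(bit_max)
    return lc_scale, sq_scale


lc_scale, sq_scale = symmetric_scales(1.5)
print(f"bit_range / 2 scale: {lc_scale:.4f}, bit_max scale: {sq_scale:.4f}")
```

Under these assumptions the second formula yields a scale about 7% larger (a factor of 7.5 / 7), so the largest weight in a group maps exactly onto the top positive level instead of slightly beyond it.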

<details>
<summary>The following python script was used to generate these
results</summary>

Clone the SpinQuant repo and paste this into the top-level directory:
```python
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from typing import Literal
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from torch import nn
import lm_eval

from transformers import LlamaForCausalLM, AutoTokenizer
import transformers
from train_utils.main import prepare_model
from train_utils.modeling_llama_quant import LlamaForCausalLM as LlamaForCausalLMQuant
from utils.hadamard_utils import random_hadamard_matrix, hadamard_matrix
from utils.process_args import process_args_ptq

# model_id = "meta-llama/Llama-3.1-8B-Instruct"
# model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "meta-llama/Llama-3.2-1B-Instruct"
dtype = torch.bfloat16


class RotateModule(nn.Module):
    def __init__(self, R_init):
        super(RotateModule, self).__init__()
        self.weight = nn.Parameter(R_init.to(torch.float32).to(torch.device("cuda")))

    def forward(self, x, transpose=False):
        if transpose:
            return x @ self.weight
        else:
            return self.weight @ x


def get_sq_model(
    r1r2: Literal["eye", "random-hadamard", "hadamard"],
    w_bits: Literal[4, 16],
    w_clip: bool = False,
) -> LlamaForCausalLMQuant:
    model_args, training_args, ptq_args = process_args_ptq()
    model_args.input_model = model_id
    if w_bits == 4:
        ptq_args.w_bits = 4
        ptq_args.w_groupsize = 128
        ptq_args.w_rtn = True  # if False, GPTQ is used
        ptq_args.w_clip = w_clip
    ptq_args.a_bits = 16
    ptq_args.k_bits = 16
    ptq_args.v_bits = 16

    print("=======ARGS=======", ptq_args)

    config = transformers.AutoConfig.from_pretrained(model_args.input_model)

    # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings, clone lm_head from embed_tokens
    process_word_embeddings = False
    if config.tie_word_embeddings:
        config.tie_word_embeddings = False
        process_word_embeddings = True

    model = LlamaForCausalLMQuant.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        config=config,
        torch_dtype=dtype,
        device_map="cuda",
    )

    if process_word_embeddings:
        model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

    model = prepare_model(ptq_args, model)
    for param in model.parameters():
        param.requires_grad = False
    match r1r2:
        case "eye":
            R1 = torch.eye(model.config.hidden_size, device="cuda")
        case "random-hadamard":
            R1 = random_hadamard_matrix(model.config.hidden_size, "cuda")
        case _:
            R1 = hadamard_matrix(model.config.hidden_size, "cuda")
    model.R1 = RotateModule(R1)
    for i in range(model.config.num_hidden_layers):
        # R2 is sized to the attention head dim (hidden_size // num_attention_heads)
        match r1r2:
            case "eye":
                R2 = torch.eye(
                    model.config.hidden_size // model.config.num_attention_heads,
                    device="cuda",
                )
            case "random-hadamard":
                R2 = random_hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads, "cuda"
                )
            case _:
                R2 = hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads, "cuda"
                )
        model.model.layers[i].self_attn.R2 = RotateModule(R2)

    model.config.use_cache = False

    return model


def get_lc_model(
    r1r2: Literal["eye", "random-hadamard", "hadamard"], w_bits: Literal[4, 16]
) -> LlamaForCausalLM:
    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.modifiers.transform import SpinQuantModifier

    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_id,
        torch_dtype=dtype,
        device_map="cuda",
    )

    recipe = [
        SpinQuantModifier(
            rotations=[] if r1r2 == "eye" else ["R1", "R2"],
            transform_type="hadamard",
        )
    ]
    if w_bits == 4:
        recipe.append(
            QuantizationModifier(
                targets="Linear",
                scheme="W4A16",
                ignore=["lm_head"],
            )
        )

    oneshot(
        model=model,
        recipe=recipe,
        pipeline="datafree",
        log_dir=None,
    )

    return model


if __name__ == "__main__":
    for scales_impl in ["sq_min_hack", "lc_min_hack"]:
        for r1r2 in ["eye", "hadamard"]:
            for sq_lc in ["sq", "lc"]:
                w_bits = 4

                os.environ["SCALES_IMPL"] = scales_impl

                model = (
                    get_sq_model(r1r2=r1r2, w_bits=w_bits)
                    if sq_lc == "sq"
                    else get_lc_model(r1r2=r1r2, w_bits=w_bits)
                ).to("cuda")

                SAVE_DIR = model_id.split("/")[1] + f"-{scales_impl}-{r1r2}-w4a16"
                model.save_pretrained(SAVE_DIR, save_compressed=True)
                tokenizer = AutoTokenizer.from_pretrained(
                    model_id, trust_remote_code=True
                )
                tokenizer.save_pretrained(SAVE_DIR)

                del model
                del tokenizer
                torch.cuda.empty_cache()

                results = lm_eval.simple_evaluate(
                    # 1) hf in-memory
                    # model=lm_eval.models.huggingface.HFLM(
                    #     pretrained=model,
                    #     batch_size=32,
                    #     add_bos_token=False,
                    # ),
                    # 1/)
                    # 2) vllm serialized
                    model="vllm",
                    model_args={
                        "pretrained": SAVE_DIR,
                        "add_bos_token": False,
                        "dtype": "auto",
                        "max_model_len": 4096,
                        "gpu_memory_utilization": 0.5,
                        "enable_chunked_prefill": True,
                    },
                    # 2/)
                    # 3) hf serialized
                    # model="hf",
                    # model_args={
                    #     "pretrained": SAVE_DIR,
                    #     "add_bos_token": False,
                    #     "dtype": "auto",
                    # },
                    # device="cuda",
                    # 3/)
                    tasks=["gsm8k_llama", "gsm8k", "arc_challenge_llama"],
                    num_fewshot=8,
                    batch_size=32,
                    apply_chat_template=True,
                    fewshot_as_multiturn=True,
                )
                print(
                    f"RESULTS, {model_id} {sq_lc} R1R2 {r1r2} W_BITS {w_bits} SCALEIMPL {scales_impl}"
                )
                print(lm_eval.utils.make_table(results))
```
</details>


## Follow Ups ##
* Infer data free pipeline, even if a transform modifier is included
* Rotations R3 and R4
* Modify example to use GPTQ once basic evaluation has been performed

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
dsikka added a commit to vllm-project/llm-compressor that referenced this pull request Aug 14, 2025
## Purpose ##
* Enable quip-style transforms

## Prerequisites ##
* neuralmagic/compressed-tensors#370
* neuralmagic/compressed-tensors#412
* neuralmagic/compressed-tensors#414

## Changes ##
* Added `quip_example.py` to examples folder
* As made clear in the disclaimer, this example requires minimum
versions of compressed-tensors and transformers to run
* Added `QuIPModifier` which handles the construction of a quip-style
transform config

## Testing ##
* Added modifier serialization and correctness tests

## Evaluation ##
Evaluation performed by @brian-dellabetta 

Evals on Llama 3.2 1B with Quip (num_fewshot 8, limit 1000 to be
compatible with results
[here](https://github.com/vllm-project/llm-compressor/pull/1243/files#diff-bdc27f23c0dc2da352d5c83abdc0f267873edf4d36f88474038b975df75bd8c3R38-R64)):

| Strat | gsm8k,strict | gsm8k_llama,strict |
|-|-|-|
| FP16 | .352 | .323 |
| Quip | .348 | .322 |
| W4A16 | .180 | .017 |
| Quip+W4A16 | .213 | .141 |

## Follow Ups ##
* Infer data free pipeline, even if a transform modifier is included
* Modify example to use GPTQ once basic evaluation has been performed

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>